{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CNN-static implementation with Keras & Gensim\n",
    "- A model with pre-trained vectors from word2vec. \n",
    "- All words including the unknown ones that are randomly initialized are kept static and only the other parameters of the model are learned.\n",
    "- The dataset could be downloaded at: http://ai.stanford.edu/~amaas/data/sentiment/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import re\n",
    "import random\n",
    "\n",
    "from gensim.models import Word2Vec\n",
    "from sklearn.model_selection import train_test_split, GridSearchCV, RandomizedSearchCV\n",
    "from sklearn.metrics import accuracy_score\n",
    "from multiprocessing import Pool\n",
    "from nltk.corpus import stopwords\n",
    "\n",
    "from keras.models import Sequential, Model\n",
    "from keras.layers import Dense, Dropout, Flatten, Activation, concatenate, Input\n",
    "from keras.layers.convolutional import Conv2D, MaxPooling2D\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras import backend as K\n",
    "from keras import optimizers\n",
    "from keras.wrappers.scikit_learn import KerasClassifier\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "from keras.constraints import maxnorm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "stop_words = set(stopwords.words('english'))    # english stopwords from nltk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# lists to contain reviews and labels (pos/neg)\n",
    "review_list = []\n",
    "labels_list = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_instances = 5000   # number of instances to consider. Actual number of whole data instances would be (num_instances * 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# reading text files. \n",
    "files = os.listdir('aclImdb/train/pos')[:num_instances]\n",
    "for file in files:\n",
    "    with open('aclImdb/train/pos/{}'.format(file), 'r', encoding = 'utf-8') as f:\n",
    "        sentence = [word for word in f.read().split() if word not in stop_words]\n",
    "        sentence = [word.lower() for word in sentence if re.match('^[a-zA-Z]+', word)]\n",
    "        f.close()\n",
    "    review_list.append(sentence)\n",
    "    labels_list.append(1)\n",
    "\n",
    "files = os.listdir('aclImdb/train/neg')[:num_instances]\n",
    "for file in files:\n",
    "    review = ''\n",
    "    with open('aclImdb/train/neg/{}'.format(file), 'r', encoding = 'utf-8') as f:\n",
    "        sentence = [word for word in f.read().split() if word not in stop_words]\n",
    "        sentence = [word.lower() for word in sentence if re.match('^[a-zA-Z]+', word)]\n",
    "        f.close()\n",
    "    review_list.append(sentence)\n",
    "    labels_list.append(0)\n",
    "\n",
    "files = os.listdir('aclImdb/test/pos')[:num_instances]\n",
    "for file in files:\n",
    "    review = ''\n",
    "    with open('aclImdb/test/pos/{}'.format(file), 'r', encoding = 'utf-8') as f:\n",
    "        sentence = [word for word in f.read().split() if word not in stop_words]\n",
    "        sentence = [word.lower() for word in sentence if re.match('^[a-zA-Z]+', word)]\n",
    "        f.close()\n",
    "    review_list.append(sentence)\n",
    "    labels_list.append(1)\n",
    "\n",
    "files = os.listdir('aclImdb/test/neg')[:num_instances]\n",
    "for file in files:\n",
    "    review = ''\n",
    "    with open('aclImdb/test/neg/{}'.format(file), 'r', encoding = 'utf-8') as f:\n",
    "        sentence = [word for word in f.read().split() if word not in stop_words]\n",
    "        sentence = [word.lower() for word in sentence if re.match('^[a-zA-Z]+', word)]\n",
    "        f.close()\n",
    "    review_list.append(sentence)\n",
    "    labels_list.append(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20000\n",
      "20000\n"
     ]
    }
   ],
   "source": [
    "print(len(review_list))\n",
    "print(len(labels_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1508\n"
     ]
    }
   ],
   "source": [
    "max_len = len(max(review_list, key = len))    # maximum length of a sentence\n",
    "print(max_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "threshold = max_len = 500   # in order to cut out excessively long sentences, define a thresold (i.e., a maximum number of words to be considered)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# if the length of the sentence is longer than the threshold, exclude that sentence\n",
    "for i in range(len(review_list)):\n",
    "    if len(review_list[i]) > threshold :\n",
    "        review_list[i] = None\n",
    "        labels_list[i] = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "review_list = [rev for rev in review_list if rev is not None] \n",
    "labels_list = [rev for rev in labels_list if rev is not None] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19815\n",
      "19815\n"
     ]
    }
   ],
   "source": [
    "print(len(review_list))\n",
    "print(len(labels_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "embed_dim = 100    # assign the dimension of the embedding space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model = Word2Vec(sentences = review_list, size = embed_dim, sg = 1, window = 5, min_count = 1, iter = 10, workers = Pool()._processes)\n",
    "model.init_sims(replace = True)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# create a 3-D numpy array to carry X data\n",
    "X_data = np.zeros((len(review_list), max_len, embed_dim))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for i in range(len(review_list)):\n",
    "    for j in range(max_len):\n",
    "        try:\n",
    "            X_data[i][j] = model[review_list[i][j]]\n",
    "        except:\n",
    "            pass   # if the word is not included in the embedding space, assign zero (i.e., zero padding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(19815, 500, 100, 1)\n"
     ]
    }
   ],
   "source": [
    "X_data = X_data.reshape(X_data.shape[0], X_data.shape[1], X_data.shape[2], 1)    # reshape the data into 4-D shape\n",
    "print(X_data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X_data, labels_list, test_size = 0.3, random_state = 777)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# assign the hyperparameters of the model\n",
    "filter_sizes = [3, 4, 5]\n",
    "dropout_rate = 0.5\n",
    "l2_constraint = 3.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def convolution():\n",
    "    inn = Input(shape = (max_len, embed_dim, 1))\n",
    "    convolutions = []\n",
    "    # we conduct three convolutions & poolings then concatenate them.\n",
    "    for fs in filter_sizes:\n",
    "        conv = Conv2D(filters = 100, kernel_size = (fs, embed_dim), strides = 1, padding = \"valid\")(inn)\n",
    "        nonlinearity = Activation('relu')(conv)\n",
    "        maxpool = MaxPooling2D(pool_size = (max_len - fs + 1, 1), padding = \"valid\")(nonlinearity)\n",
    "        convolutions.append(maxpool)\n",
    "    outt = concatenate(convolutions)\n",
    "    model = Model(input = inn, output = outt)\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def cnn_model():\n",
    "    convolutions = convolution()\n",
    "    \n",
    "    model = Sequential()\n",
    "    model.add(convolutions)\n",
    "    model.add(Dropout(dropout_rate))\n",
    "    \n",
    "    model.add(Flatten())\n",
    "    model.add(Dense(1, kernel_constraint=maxnorm(l2_constraint), activation = 'sigmoid'))\n",
    "    \n",
    "    adam = optimizers.Adam(lr = 0.01)\n",
    "    model.compile(optimizer = adam, loss = 'binary_crossentropy', metrics = ['accuracy'])\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "model_2 (Model)              (None, 1, 1, 300)         120300    \n",
      "_________________________________________________________________\n",
      "dropout_2 (Dropout)          (None, 1, 1, 300)         0         \n",
      "_________________________________________________________________\n",
      "flatten_2 (Flatten)          (None, 300)               0         \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 1)                 301       \n",
      "=================================================================\n",
      "Total params: 120,601\n",
      "Trainable params: 120,601\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Buomsoo\\Anaconda3\\lib\\site-packages\\ipykernel_launcher.py:10: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor(\"in..., outputs=Tensor(\"co...)`\n",
      "  # Remove the CWD from sys.path while we load stuff.\n"
     ]
    }
   ],
   "source": [
    "cnn = cnn_model()\n",
    "cnn.summary()    # summary of the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "cnn_model = KerasClassifier(build_fn = cnn_model, epochs = 100, batch_size = 50, verbose = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Buomsoo\\Anaconda3\\lib\\site-packages\\ipykernel_launcher.py:10: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor(\"in..., outputs=Tensor(\"co...)`\n",
      "  # Remove the CWD from sys.path while we load stuff.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "13870/13870 [==============================] - 19s - loss: 0.4020 - acc: 0.8125    \n",
      "Epoch 2/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.3068 - acc: 0.8684     \n",
      "Epoch 3/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.2590 - acc: 0.8934     \n",
      "Epoch 4/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.2319 - acc: 0.9043     \n",
      "Epoch 5/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.1902 - acc: 0.9232     \n",
      "Epoch 6/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.1623 - acc: 0.9358     \n",
      "Epoch 7/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.1376 - acc: 0.9469     \n",
      "Epoch 8/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.1256 - acc: 0.9540     \n",
      "Epoch 9/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.1071 - acc: 0.9629     \n",
      "Epoch 10/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0966 - acc: 0.9650     \n",
      "Epoch 11/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0847 - acc: 0.9716     \n",
      "Epoch 12/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0845 - acc: 0.9710     \n",
      "Epoch 13/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0751 - acc: 0.9743     \n",
      "Epoch 14/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0737 - acc: 0.9746     \n",
      "Epoch 15/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0655 - acc: 0.9785     \n",
      "Epoch 16/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0611 - acc: 0.9807     \n",
      "Epoch 17/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0599 - acc: 0.9799     \n",
      "Epoch 18/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0586 - acc: 0.9822     \n",
      "Epoch 19/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0543 - acc: 0.9812     \n",
      "Epoch 20/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0581 - acc: 0.9808     \n",
      "Epoch 21/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0512 - acc: 0.9838     \n",
      "Epoch 22/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0487 - acc: 0.9851     \n",
      "Epoch 23/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0472 - acc: 0.9833     \n",
      "Epoch 24/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0413 - acc: 0.9870     \n",
      "Epoch 25/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0471 - acc: 0.9846     \n",
      "Epoch 26/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0414 - acc: 0.9873     \n",
      "Epoch 27/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0491 - acc: 0.9839     \n",
      "Epoch 28/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0443 - acc: 0.9850     \n",
      "Epoch 29/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0393 - acc: 0.9867     \n",
      "Epoch 30/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0382 - acc: 0.9871     \n",
      "Epoch 31/100\n",
      "13870/13870 [==============================] - 13s - loss: 0.0330 - acc: 0.9898    \n",
      "Epoch 32/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0411 - acc: 0.9862     \n",
      "Epoch 33/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0378 - acc: 0.9879     \n",
      "Epoch 34/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0355 - acc: 0.9878     \n",
      "Epoch 35/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0366 - acc: 0.9892     \n",
      "Epoch 36/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0360 - acc: 0.9881     \n",
      "Epoch 37/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0328 - acc: 0.9898     \n",
      "Epoch 38/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0386 - acc: 0.9880     \n",
      "Epoch 39/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0373 - acc: 0.9874     \n",
      "Epoch 40/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0356 - acc: 0.9890     \n",
      "Epoch 41/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0335 - acc: 0.9894     \n",
      "Epoch 42/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0262 - acc: 0.9925     \n",
      "Epoch 43/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0312 - acc: 0.9901     \n",
      "Epoch 44/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0316 - acc: 0.9901     \n",
      "Epoch 45/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0274 - acc: 0.9913     \n",
      "Epoch 46/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0314 - acc: 0.9892     \n",
      "Epoch 47/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0309 - acc: 0.9898     \n",
      "Epoch 48/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0309 - acc: 0.9899     \n",
      "Epoch 49/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0320 - acc: 0.9894     \n",
      "Epoch 50/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0315 - acc: 0.9907     \n",
      "Epoch 51/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0245 - acc: 0.9921     \n",
      "Epoch 52/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0259 - acc: 0.9924     \n",
      "Epoch 53/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0326 - acc: 0.9897     \n",
      "Epoch 54/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0216 - acc: 0.9934     \n",
      "Epoch 55/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0251 - acc: 0.9923     \n",
      "Epoch 56/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0260 - acc: 0.9916     \n",
      "Epoch 57/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0243 - acc: 0.9925     \n",
      "Epoch 58/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0266 - acc: 0.9916     \n",
      "Epoch 59/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0229 - acc: 0.9929     \n",
      "Epoch 60/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0235 - acc: 0.9916     \n",
      "Epoch 61/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0303 - acc: 0.9901     \n",
      "Epoch 62/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0221 - acc: 0.9925     \n",
      "Epoch 63/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0295 - acc: 0.9911     \n",
      "Epoch 64/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0274 - acc: 0.9902     \n",
      "Epoch 65/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0241 - acc: 0.9924     \n",
      "Epoch 66/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0219 - acc: 0.9934     \n",
      "Epoch 67/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0178 - acc: 0.9942     \n",
      "Epoch 68/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0236 - acc: 0.9929     \n",
      "Epoch 69/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0221 - acc: 0.9934     \n",
      "Epoch 70/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0244 - acc: 0.9927     \n",
      "Epoch 71/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0192 - acc: 0.9937     \n",
      "Epoch 72/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0234 - acc: 0.9926     \n",
      "Epoch 73/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0220 - acc: 0.9941     \n",
      "Epoch 74/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0191 - acc: 0.9942     \n",
      "Epoch 75/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0181 - acc: 0.9943     \n",
      "Epoch 76/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0196 - acc: 0.9934     \n",
      "Epoch 77/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0265 - acc: 0.9913     \n",
      "Epoch 78/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0184 - acc: 0.9938     \n",
      "Epoch 79/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0214 - acc: 0.9934     \n",
      "Epoch 80/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0218 - acc: 0.9928     \n",
      "Epoch 81/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0200 - acc: 0.9939     \n",
      "Epoch 82/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0186 - acc: 0.9950     \n",
      "Epoch 83/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0223 - acc: 0.9914     \n",
      "Epoch 84/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0169 - acc: 0.9947     \n",
      "Epoch 85/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0192 - acc: 0.9946     \n",
      "Epoch 86/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0208 - acc: 0.9931     \n",
      "Epoch 87/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0247 - acc: 0.9915     \n",
      "Epoch 88/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0245 - acc: 0.9916     \n",
      "Epoch 89/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0166 - acc: 0.9950     \n",
      "Epoch 90/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0171 - acc: 0.9947     \n",
      "Epoch 91/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0176 - acc: 0.9949     \n",
      "Epoch 92/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0154 - acc: 0.9954     \n",
      "Epoch 93/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0231 - acc: 0.9913     \n",
      "Epoch 94/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0240 - acc: 0.9924     \n",
      "Epoch 95/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0204 - acc: 0.9936     \n",
      "Epoch 96/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0182 - acc: 0.9940     \n",
      "Epoch 97/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0163 - acc: 0.9956     \n",
      "Epoch 98/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0169 - acc: 0.9954     \n",
      "Epoch 99/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0155 - acc: 0.9949     \n",
      "Epoch 100/100\n",
      "13870/13870 [==============================] - 9s - loss: 0.0141 - acc: 0.9964     \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x276956f2a58>"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cnn_model.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5850/5945 [============================>.] - ETA: 0s"
     ]
    }
   ],
   "source": [
    "y_pred = cnn_model.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.883095037847\n"
     ]
    }
   ],
   "source": [
    "print(accuracy_score(y_test, y_pred))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
